
Add Support for 2/3/8-bit GPTQ Quantization Models #2330

Merged · 8 commits into vllm-project:main · Feb 29, 2024

Conversation

chu-tianxiang (Contributor)

There's already a pull request supporting varying quantization bit levels for GPTQ models, leveraging kernels from the AutoGPTQ repository. This PR presents an alternative approach inspired by exllamav2.

While exllamav2 doesn't natively support 2/3/8-bit GPTQ models, it possesses the essential components. In essence, EXL2 operates as a mixed-bit GPTQ format, and 2/3/8-bit models can be seen as special cases of it. Although there are minor differences in scales and zero points, these can be easily adjusted.
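
For readers unfamiliar with the format, below is a minimal sketch (not the exllamav2-style kernel added in this PR) of how a GPTQ weight matrix packed at 2, 4, or 8 bits can be dequantized in PyTorch; 3-bit values straddle int32 boundaries and need extra handling. Tensor names follow the common GPTQ checkpoint layout, and the `+ 1` offset reflects the zero-point storage convention of older GPTQ checkpoints.

```python
# Illustrative dequantization sketch for GPTQ checkpoints packed at 2/4/8 bits.
import torch

def dequantize_gptq(qweight, qzeros, scales, bits, group_size):
    # qweight: [in_features // (32 // bits), out_features]               (int32)
    # qzeros:  [in_features // group_size, out_features // (32 // bits)] (int32)
    # scales:  [in_features // group_size, out_features]                 (fp16)
    mask = (1 << bits) - 1
    shifts = torch.arange(0, 32, bits, device=qweight.device, dtype=torch.int32)

    # Unpack quantized weights along the input dimension.
    q = (qweight.unsqueeze(1) >> shifts.view(1, -1, 1)) & mask
    q = q.reshape(-1, qweight.shape[1])                # [in_features, out_features]

    # Unpack zero points along the output dimension.
    z = (qzeros.unsqueeze(2) >> shifts.view(1, 1, -1)) & mask
    z = z.reshape(qzeros.shape[0], -1)                 # [groups, out_features]

    # Map each input row to its quantization group, then dequantize.
    g = torch.arange(q.shape[0], device=q.device) // group_size
    # Older GPTQ checkpoints store (zero - 1), hence the + 1 here.
    return (q - (z[g] + 1)).to(scales.dtype) * scales[g]
```

Nothing in this recipe depends on the bit width being 4, which is the intuition behind reusing the exllamav2 machinery for the other bit widths.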

Below is a comparison of the latency and throughput of Llama2-7B at different quantization bit widths on a single A100. The 4-bit numbers are from the main branch with the CUDA graph fix, while 3-bit and 8-bit are newly added. All results were measured with the benchmark_latency.py and benchmark_throughput.py scripts. (2-bit GPTQ models can hardly generate coherent output and are of no practical value, so they are not included below.)

| Bits | Single-Query Speed | Throughput |
|------|--------------------|------------|
| fp16 | 93 tokens/s | 9.24 requests/s |
| 4 | 190 tokens/s | 7.76 requests/s |
| 3 | 206 tokens/s | 7.80 requests/s |
| 8 | 135 tokens/s | 7.38 requests/s |

This has not been tested on ROCm devices yet.
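
For context (this is not part of the PR diff), running one of these checkpoints through vLLM's offline API is just a matter of pointing at a GPTQ model; a minimal sketch with a placeholder model name:

```python
# Minimal usage sketch; the model name is a placeholder for any GPTQ checkpoint
# at a supported bit width. vLLM can usually detect the quantization config from
# the checkpoint, but it can also be forced explicitly as shown.
from vllm import LLM, SamplingParams

llm = LLM(model="TheBloke/Llama-2-7B-GPTQ", quantization="gptq")
params = SamplingParams(temperature=0.8, max_tokens=128)
outputs = llm.generate(["Explain GPTQ quantization in one sentence."], params)
print(outputs[0].outputs[0].text)
```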

@JasonZhu1313 (Contributor)

Hey @chu-tianxiang, what's the request rate / QPS for your throughput test? Any intuition on why we've seen ~2x tokens per second but lower throughput?

@chu-tianxiang (Contributor, Author)

> Hey @chu-tianxiang, what's the request rate / QPS for your throughput test? Any intuition on why we've seen ~2x tokens per second but lower throughput?

I used benchmark_throughput.py, which adds all requests before running inference rather than sending them at a fixed request rate, so there is no request rate / QPS.
The ~2x tokens per second only shows up at very low batch sizes, where the memory wall dominates; at large batch sizes the performance is actually worse than fp16 because of the extra computational cost of dequantization.
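
To make the batch-size point concrete, here is a back-of-envelope sketch (assumed ballpark numbers: 7B parameters, roughly 2 TB/s of A100 HBM bandwidth); it only bounds the bandwidth-limited single-query regime and ignores activations, the KV cache, and kernel overhead.

```python
# Rough bandwidth-bound ceiling for batch-size-1 decoding: every weight is
# streamed from HBM once per generated token, so the ceiling scales inversely
# with weight size. Numbers are ballpark assumptions, not measurements.
PARAMS = 7e9        # Llama2-7B parameter count
HBM_BW = 2.0e12     # ~2 TB/s A100 HBM bandwidth (assumed)

for bits in (16, 8, 4, 3):
    weight_bytes = PARAMS * bits / 8
    ceiling = HBM_BW / weight_bytes  # tokens/s if purely bandwidth-bound
    print(f"{bits:>2}-bit weights: {weight_bytes / 1e9:5.1f} GB, "
          f"ceiling ≈ {ceiling:4.0f} tokens/s")
```

At large batch sizes the GEMMs become compute-bound instead, so the smaller weights no longer help and the dequantization work shows up as a net cost, which matches the throughput column in the table above.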

@raywanb (Contributor) commented Feb 23, 2024

Hey @chu-tianxiang, can you please update this PR to the latest master branch?

@aliencaocao (Contributor)

looking forward to this getting merged!

@simon-mo requested a review from @WoosukKwon on February 28, 2024.
@esmeetu (Collaborator) commented Feb 28, 2024

@chu-tianxiang I have tested this feature using the model https://huggingface.co/TheBloke/WizardCoder-33B-V1.1-GPTQ/tree/gptq-8bit--1g-actorder_True. It works when setting max-model-len=8192, but not with 16384, which causes an illegal memory access, and I have no idea why. Besides, I changed max_position_embeddings in config.json to 4096.
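
(The exact invocation wasn't posted; a hypothetical setup matching this description, with all arguments assumed, would look roughly like the sketch below.)

```python
# Hypothetical reproduction sketch based on the description above; the actual
# setup and serving path used in the report may differ.
from vllm import LLM

llm = LLM(
    model="TheBloke/WizardCoder-33B-V1.1-GPTQ",
    revision="gptq-8bit--1g-actorder_True",
    quantization="gptq",
    max_model_len=16384,  # reportedly fine at 8192, crashes at 16384
)
```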

@WoosukKwon mentioned this pull request on Feb 28, 2024.
@WoosukKwon (Collaborator) left a comment


LGTM! Awesome! Thanks for the PR and apologies for the delayed review.

@aliencaocao (Contributor)

GPTQ 8-bit doesn't work on V100: it cannot compile because the Marlin and QuIP kernels require sm80 and above. Any plan to fix that, since V100 is in the official supported list?

@chu-tianxiang (Contributor, Author)

@esmeetu I tested the model at that link and could not reproduce the illegal memory access error. Could you please provide more details about the setup and code?

@aliencaocao This PR doesn't include the Marlin or QuIP kernels; I guess you're referring to the gptq_hf branch. I'll add a CUDA arch guard for those kernels. Thanks for the report.
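
For illustration only, a runtime guard of the kind mentioned could look like the sketch below (the helper name is hypothetical, and this is not the actual change made on that branch):

```python
# Hypothetical compute-capability guard: skip kernels that need sm80+ (Ampere)
# on older GPUs such as the V100 (sm70).
import torch

def device_supports_sm80() -> bool:
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) >= (8, 0)

if not device_supports_sm80():
    raise RuntimeError("These kernels require compute capability >= 8.0 (sm80).")
```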

@WoosukKwon merged commit 01a5d18 into vllm-project:main on Feb 29, 2024 (22 checks passed).
@aliencaocao (Contributor)

Yes, I meant the gptq_hf branch. I figured it out myself by removing all the QuIP and Marlin code, and it works for me.
